"""
Extended Lindblad Model Implementation
For: A Computational Quantum Model for Periodontal Low‑Level Laser Therapy
Version: 3.0
"""

import numpy as np
from qutip import *

class ExtendedTLSModel:
    """
    Extended Two-Level System model with separate relaxation and dephasing.

    Implements the Lindblad master equation (ħ = 1)

        dρ/dt = -i[H, ρ] + γ₁ D[σ₋]ρ + (γ_φ/2) D[σ_z]ρ,   H = (Ω/2) σ_x,

    where D[L]ρ = LρL† - ½{L†L, ρ}. With this normalisation the coherence
    ρ_ge decays at γ₂ = γ₁/2 + γ_φ. The same γ₂ is used by
    ``coherence_lifetime`` and ``steady_state_solution``, so the analytic
    and numerical results agree.

    Basis convention: |g⟩ = basis(2, 0), |e⟩ = basis(2, 1). The ladder
    operators are built explicitly from these kets instead of using QuTiP's
    ``sigmam()``/``sigmap()``. QuTiP's ``sigmam()`` maps basis(2, 0) to
    basis(2, 1), so in this labelling it would pump |g⟩ -> |e⟩ rather than
    relax |e⟩ -> |g⟩.
    """

    def __init__(self, Gamma=1.0):
        """
        Initialize model with characteristic rate Gamma (MHz).

        Gamma is stored for unit conversion only; no method of this class
        reads it.
        """
        self.Gamma = Gamma  # 1 MHz, for unit conversion only

        # Define basis states
        self.ground = basis(2, 0)  # |g⟩
        self.excited = basis(2, 1)  # |e⟩

        # Pauli operators. σ_x is symmetric and the sign of σ_z is irrelevant
        # for dephasing, so the QuTiP versions are safe to use directly.
        self.sigma_x = sigmax()  # Pauli X
        self.sigma_z = sigmaz()  # Pauli Z

        # Ladder operators built from the kets so they match this labelling
        # (QuTiP's sigmam()/sigmap() follow the opposite convention).
        self.sigma_m = self.ground * self.excited.dag()  # σ₋ = |g⟩⟨e|
        self.sigma_p = self.excited * self.ground.dag()  # σ₊ = |e⟩⟨g|

        # Define projection operators
        self.P_e = self.excited * self.excited.dag()  # |e⟩⟨e|

    def hamiltonian(self, Omega):
        """
        Hamiltonian in rotating frame (semiclassical approximation).

        H = (ħΩ/2) σ_x with ħ = 1, i.e. resonant driving at Rabi frequency Ω.
        """
        return 0.5 * Omega * self.sigma_x

    def collapse_operators(self, gamma_1, gamma_phi):
        """
        Return collapse operators for the Lindblad equation.

        - γ₁: relaxation rate (energy dissipation), operator √γ₁ σ₋
        - γ_φ: pure dephasing rate (phase randomization), operator √(γ_φ/2) σ_z

        Rates that are not strictly positive contribute no operator.
        """
        c_ops = []
        if gamma_1 > 0:
            # Energy relaxation |e⟩ -> |g⟩ at rate γ₁.
            c_ops.append(np.sqrt(gamma_1) * self.sigma_m)
        if gamma_phi > 0:
            # D[√(γ_φ/2) σ_z] damps ρ_ge at exactly γ_φ. Using √γ_φ σ_z would
            # damp it at 2γ_φ and break γ₂ = γ₁/2 + γ_φ.
            c_ops.append(np.sqrt(gamma_phi / 2) * self.sigma_z)
        return c_ops

    def steady_state_solution(self, Omega, gamma_1, gamma_phi):
        """
        Analytical steady-state excited population for resonant driving.

        From the optical Bloch equations with γ₂ = γ₁/2 + γ_φ:

            P_e = Ω² / (2Ω² + 2γ₁γ₂) = Ω² / (γ₁² + 2γ₁γ_φ + 2Ω²)

        For γ_φ = 0 this reduces to Ω² / (γ₁² + 2Ω²).

        Returns 0.0 when the denominator vanishes. That requires Ω = 0, where
        no population is ever driven out of |g⟩.
        """
        denominator = gamma_1**2 + 2 * gamma_1 * gamma_phi + 2 * Omega**2
        if denominator == 0:
            return 0.0
        return Omega**2 / denominator

    def quantum_jump_rate(self, P_e, gamma_1):
        """
        Quantum jump rate J = γ₁ P_e.

        Represents physical excitation flux (same units as γ₁).
        """
        return gamma_1 * P_e

    def coherence_lifetime(self, gamma_1, gamma_phi):
        """
        Coherence lifetime τ = 1/γ₂ where γ₂ = γ₁/2 + γ_φ.

        Returns ``float('inf')`` when γ₂ = 0 (no decoherence).
        """
        gamma_2 = 0.5 * gamma_1 + gamma_phi
        if gamma_2 == 0:
            return float('inf')
        return 1.0 / gamma_2

    def simulate_dynamics(self, Omega, gamma_1, gamma_phi,
                          initial_state=None, t_max=100, n_points=1000):
        """
        Simulate time evolution of the system.

        Parameters:
        - Omega: Rabi frequency (MHz)
        - gamma_1: relaxation rate (MHz)
        - gamma_phi: dephasing rate (MHz)
        - initial_state: initial ket or density matrix (default: ground state)
        - t_max: maximum time in µs
        - n_points: number of time points

        Returns:
        - times: array of time points (µs)
        - P_e_t: excited population vs time
        - coherence_t: magnitude of coherence |ρ_ge(t)|
        """
        # Default to ground state
        if initial_state is None:
            initial_state = self.ground

        # Time array
        times = np.linspace(0, t_max, n_points)

        # Hamiltonian and collapse operators
        H = self.hamiltonian(Omega)
        c_ops = self.collapse_operators(gamma_1, gamma_phi)

        # Expectation operators: ⟨|e⟩⟨e|⟩ = ρ_ee and ⟨σ₋⟩ = Tr(ρ|g⟩⟨e|) = ρ_eg.
        # ⟨σ_x⟩ = 2 Re ρ_ge is NOT used: it vanishes for resonant σ_x
        # driving from |g⟩, where ρ_ge is purely imaginary.
        e_ops = [self.P_e, self.sigma_m]

        # Solve master equation (keywords keep this valid across QuTiP 4/5)
        result = mesolve(H, initial_state, times, c_ops=c_ops, e_ops=e_ops)

        # Extract results
        P_e_t = result.expect[0]  # Excited population
        coherence_t = np.abs(result.expect[1])  # |ρ_eg| = |ρ_ge|

        return times, P_e_t, coherence_t

    def calculate_parameter_sweep(self, Omega_values, gamma_phi_values,
                                  gamma_1=0.001):
        """
        Calculate steady-state results over an (Ω, γ_φ) parameter grid.

        The grid is traversed with Ω as the outer loop and γ_φ as the inner
        loop; all lists in the result are aligned index by index.

        Returns:
        - dict with keys 'Omega', 'gamma_phi', 'P_e', 'J', 'tau', each a list
        """
        results = {
            'Omega': [],
            'gamma_phi': [],
            'P_e': [],
            'J': [],
            'tau': []
        }

        for Omega in Omega_values:
            for gamma_phi in gamma_phi_values:
                P_e = self.steady_state_solution(Omega, gamma_1, gamma_phi)
                J = self.quantum_jump_rate(P_e, gamma_1)
                tau = self.coherence_lifetime(gamma_1, gamma_phi)

                results['Omega'].append(Omega)
                results['gamma_phi'].append(gamma_phi)
                results['P_e'].append(P_e)
                results['J'].append(J)
                results['tau'].append(tau)

        return results

# Example usage: evaluate the steady-state quantities for the manuscript's
# reference parameters and print a short report.
if __name__ == "__main__":
    model = ExtendedTLSModel(Gamma=1.0)

    # Reference parameters quoted in the manuscript (all rates in MHz).
    Omega_test, gamma_1_test, gamma_phi_test = 0.20, 0.001, 0.10

    # Steady-state population, jump rate and coherence lifetime.
    P_e = model.steady_state_solution(Omega_test, gamma_1_test, gamma_phi_test)
    J = model.quantum_jump_rate(P_e, gamma_1_test)
    tau = model.coherence_lifetime(gamma_1_test, gamma_phi_test)

    # Total decoherence rate γ₂ = γ₁/2 + γ_φ, shown alongside its inverse.
    gamma_2 = gamma_phi_test + gamma_1_test / 2

    report = (
        f"Ω = {Omega_test} MHz",
        f"γ₁ = {gamma_1_test} MHz",
        f"γ_φ = {gamma_phi_test} MHz",
        f"P_e = {P_e:.4f}",
        f"J = {J:.6f} MHz",
        f"τ = {tau:.2f} µs",
        f"γ₂ = {gamma_2:.4f} MHz",
        f"1/γ₂ = {1/gamma_2:.2f} µs",
    )
    for line in report:
        print(line)